"""Box‑counting dimension estimation.

This module implements a simple box‑counting estimator for the
fractal dimension of a point set.  Given a set of points in ``d``
dimensions and a scale parameter ``n``, the algorithm divides each
dimension into ``2**n`` equal bins and counts how many of the resulting
hyper‑rectangular boxes contain at least one point.  The estimated
box‑counting dimension at scale ``n`` is then

    D(n) = log(N_occ) / (n * log(2))

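where ``N_occ`` is the number of occupied boxes.  Each axis is first
normalised by the extent of the data's (finite) bounding box, so the box
side length along every axis scales as ``r ∝ 2^{-n}`` and the estimate
does not depend on the absolute scale of the data.  Note that a constant
prefactor in the scaling law ``N_occ ≈ C · r^{-D}`` does not cancel
exactly: it adds a bias of ``log(C) / (n * log(2))`` to ``D(n)`` that
only vanishes as ``n`` grows.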

While more sophisticated estimators exist (e.g. the correlation
dimension or the information dimension), this simple approach suffices
for demonstrating the logistic pivot fitting described in the
pipeline.
"""

from __future__ import annotations

import numpy as np
from typing import Iterable, Tuple

def box_count_dimension(points: np.ndarray, scales: Iterable[int] = range(1, 8)) -> Tuple[np.ndarray, np.ndarray]:
    """Estimate box‑counting dimensions over multiple scales.

    Every axis is rescaled to ``[0, 1]`` using the data's bounding box
    and split into ``2**n`` equal bins.  With ``N_occ`` the number of
    occupied hyper-rectangular boxes, the estimate at scale ``n`` is
    ``log(N_occ) / (n * log(2))``.

    Parameters
    ----------
    points : np.ndarray
        A 2D array of shape ``(n_samples, n_features)`` representing
        coordinates of points in Euclidean space.  It must contain at
        least one point and one feature, and every value must be a
        finite number.
    scales : Iterable[int], optional
        Iterable of integer scales ``n`` at which to estimate the
        dimension.  Each scale corresponds to dividing each axis into
        ``2**n`` bins.  Scales must lie in ``0 .. 62`` so that bin
        indices fit in 64-bit integers; scale ``0`` yields ``nan``
        because the estimate divides by ``n``.  By default, uses the
        range ``1 .. 7``.

    Returns
    -------
    n_vals : np.ndarray
        1D array of the integer scales provided.
    D_vals : np.ndarray
        1D array containing the estimated box‑counting dimension for
        each scale.

    Raises
    ------
    ValueError
        If ``points`` is not 2D, has no rows or no columns, contains NaN
        or infinite values, or if any scale lies outside ``0 .. 62``.
    """
    points = np.asarray(points, dtype=float)
    if points.ndim != 2:
        raise ValueError(f"Input points must be a 2D array, got shape {points.shape}")
    n_samples, n_features = points.shape
    if n_samples == 0 or n_features == 0:
        raise ValueError(
            f"Input points must contain at least one point and one feature, got shape {points.shape}"
        )
    if not np.isfinite(points).all():
        raise ValueError("Input points must not contain NaN or infinite values")

    n_vals = np.array(list(scales), dtype=int)
    if np.any((n_vals < 0) | (n_vals > 62)):
        raise ValueError(f"Scales must be integers in the range 0 .. 62, got {n_vals.tolist()}")

    mins = points.min(axis=0)
    spans = points.max(axis=0) - mins
    # A constant axis has zero span; substitute 1.0 to avoid dividing by
    # zero.  All points then fall into bin 0 along that axis.
    ranges = np.where(spans == 0, 1.0, spans)
    # The normalisation to [0, 1] does not depend on the scale, so do it once.
    norm = (points - mins) / ranges

    D_vals = np.empty(len(n_vals), dtype=float)
    for idx, n in enumerate(n_vals):
        n = int(n)
        if n == 0:
            # log(N_occ) / (0 * log 2) is undefined.
            D_vals[idx] = np.nan
            continue
        bins_per_dim = 1 << n
        # Per-axis bin index.  Points on the upper face of the bounding box
        # (norm == 1) would get index ``bins_per_dim``; clamp them into the
        # last bin rather than relying on an epsilon-based clip.
        inds = np.minimum(np.floor(norm * bins_per_dim).astype(np.int64), bins_per_dim - 1)
        if n * n_features < 63:
            # Mixed-radix flattening idx0 + B*idx1 + B^2*idx2 + ... with
            # B = bins_per_dim.  The largest flat index is
            # 2**(n * n_features) - 1, which fits in int64 here, so distinct
            # boxes never collide.
            weights = np.int64(bins_per_dim) ** np.arange(n_features, dtype=np.int64)
            n_occ = np.unique(inds @ weights).size
        else:
            # The flat index would overflow int64 and make distinct boxes
            # collide; count unique index rows directly instead.
            n_occ = np.unique(inds, axis=0).shape[0]
        D_vals[idx] = np.log(float(n_occ)) / (n * np.log(2.0))
    return n_vals, D_vals